# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import json
import tarfile
import json
import io
import pyarrow.parquet as pq
from io import BytesIO
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import tarfile
import json
import io
import wave
import numpy as np
import torchaudio
import os
import sys
import json
import random
import pickle
import argparse
import itertools
import mmap
import struct
import collections



import shutil
import multiprocessing as mp
from pathlib import Path

from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
import pickle

from wids import wids
import math

# Decode audio through the soundfile backend.
torchaudio.set_audio_backend('soundfile')

# File extensions treated as audio.
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}

# Speaker-embedding tensors loaded once at import time from fixed checkpoint
# paths. When loading fails for any ordinary reason (missing file, corrupt
# checkpoint, ...) both fall back to all-zero tensors of shape (1, 192).
# Only `Exception` is caught so that Ctrl-C / SystemExit during import are
# not silently swallowed.
try:
    MAIN_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
    GPT_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
except Exception:
    MAIN_SPK_EMBEDDING = torch.zeros(1, 192)
    GPT_SPK_EMBEDDING = torch.zeros(1, 192)

def parquet_opener(data, mode='train', tts_data={}):
    """ Expand every parquet shard into one sample per row.

        The incoming sample dict is updated in place with the row's columns;
        a fresh shallow copy is yielded for every emitted sample.

        Args:
            data: Iterable[dict], each dict must contain 'src' (url or local
                path of a parquet file)
            mode: 'train' yields one sample per row; any other value yields
                one sample per entry of ``tts_data[utt]``. In 'inference'
                mode rows whose 'utt' is not in ``tts_data`` are skipped.
            tts_data: mapping utt -> list of texts (only read, never modified)

        Returns:
            Iterable[dict]; non-train samples additionally carry 'tts_index'
            and 'tts_text'. A shard that raises is logged and abandoned.
    """
    for sample in data:
        assert 'src' in sample
        url = sample['src']
        try:
            table = pq.read_table(url).to_pandas()
            for row in range(len(table)):
                if mode == 'inference' and table.loc[row, 'utt'] not in tts_data:
                    continue
                sample.update(dict(table.loc[row]))
                if mode != 'train':
                    texts = tts_data[table.loc[row, 'utt']]
                    for index, text in enumerate(texts):
                        yield dict(sample, tts_index=index, tts_text=text)
                    continue
                # NOTE never hand out `sample` itself, it is reused for the next row
                yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(url, ex))




# One field per entry of the struct format used by `parse_tar_header`,
# in the same order.
TarHeader = collections.namedtuple(
    "TarHeader",
    (
        "name",
        "mode",
        "uid",
        "gid",
        "size",
        "mtime",
        "chksum",
        "typeflag",
        "linkname",
        "magic",
        "version",
        "uname",
        "gname",
        "devmajor",
        "devminor",
        "prefix",
    ),
)


def parse_tar_header(header_bytes):
    """ Split a 500-byte tar header into its raw fields.

        Args:
            header_bytes: bytes-like object of exactly 500 bytes (the sum of
                the field widths in the format string below)

        Returns:
            TarHeader whose attributes are the undecoded ``bytes`` of each field

        Raises:
            struct.error: if ``header_bytes`` does not have the expected length
    """
    raw_fields = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
    return TarHeader._make(raw_fields)

class MMTar:
    def __init__(self, file_path: Path | str):
        self.stream = open(file_path, "rb")
        self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)

    def __del__(self):
        try:
            self.mmap.close()
            self.stream.close()
        except:  # noqa
            pass

    def get_at_offset(self, offset) -> tuple[str, bytes]:
        header = parse_tar_header(self.mmap[offset : offset + 500])
        name = header.name.decode("utf-8").strip("\x00")
        start = offset + 512
        end = start + int(header.size.decode("utf-8")[:-1], 8)
        return name, self.mmap[start:end]


class Tar:
    """ Name-addressed reader for a tar file with a pickled ``.index`` sidecar.

        The index (same path with suffix ``.index``) is unpickled into
        ``self.index`` and iterated as 3-tuples whose first two items are
        used as ``(name, offset)``.
    """

    def __init__(self, path: Path):
        self.tar = MMTar(path)
        # NOTE(security): pickle.loads can execute arbitrary code; only use
        # index files that come from a trusted source.
        index_file = path.with_suffix(".index")
        self.index = pickle.loads(index_file.read_bytes())
        self.name_mapping = {name: offset for name, offset, _ in self.index}

    def read(self, name: str) -> bytes:
        """ Return the raw bytes of member ``name`` (KeyError if unknown). """
        _, payload = self.tar.get_at_offset(self.name_mapping[name])
        return payload

def cosy_jsonl_opener(data, mode='train', tts_data={}):
    """ Stream speech tokens and waveforms from a jsonl manifest + tar pair.

        ``sample['src']`` is the manifest path; the audio archive path is
        derived by replacing ".vq0907.jsonl" with ".tar". Each manifest line
        is a JSON object providing 'cosy_token' and 'filename' (the member
        name inside the tar). The sample dict is updated in place and a
        shallow copy is yielded per line.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: accepted for pipeline compatibility, not used here
            tts_data: accepted for pipeline compatibility, not used here

        Returns:
            Iterable[dict] with 'speech_token', 'speech' and 'sample_rate'
            added. Any exception is logged and ends the current manifest.
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(".vq0907.jsonl", ".tar")
        try:
            archive = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as manifest:
                for record in map(json.loads, manifest):
                    sample['speech_token'] = torch.tensor(record['cosy_token'])
                    audio_bytes = archive.read(record['filename'])
                    waveform, rate = torchaudio.load(io.BytesIO(audio_bytes))
                    sample['speech'] = waveform
                    sample['sample_rate'] = rate
                    yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))


def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
    """ Stream samples from a ".vq0918-nopool.jsonl" manifest and its tar.

        The tar path is the manifest path with ".vq0918-nopool.jsonl"
        replaced by ".tar". Every manifest line (JSON with 'cosy_token' and
        'filename') produces one shallow copy of the in-place updated sample.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used by this opener
            tts_data: not used by this opener

        Returns:
            Iterable[dict] carrying 'speech_token', 'speech', 'sample_rate'.
            On any exception the manifest is abandoned with a warning.
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(".vq0918-nopool.jsonl", ".tar")

        try:
            archive = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as manifest:
                for raw_line in manifest:
                    record = json.loads(raw_line)
                    sample['speech_token'] = torch.tensor(record['cosy_token'])
                    audio_stream = io.BytesIO(archive.read(record['filename']))
                    loaded = torchaudio.load(audio_stream)
                    sample['speech'], sample['sample_rate'] = loaded
                    yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))



def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
    """ Stream samples from a ".vq0918-pool2.jsonl" manifest and its tar.

        The tar path is the manifest path with ".vq0918-pool2.jsonl"
        replaced by ".tar". Every manifest line (JSON with 'cosy_token' and
        'filename') produces one shallow copy of the in-place updated sample.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used by this opener
            tts_data: not used by this opener

        Returns:
            Iterable[dict] carrying 'speech_token', 'speech', 'sample_rate'.
            On any exception the manifest is abandoned with a warning.
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(".vq0918-pool2.jsonl", ".tar")

        try:
            archive = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as manifest:
                for record in map(json.loads, manifest):
                    sample['speech_token'] = torch.tensor(record['cosy_token'])
                    member = archive.read(record['filename'])
                    waveform, rate = torchaudio.load(io.BytesIO(member))
                    sample['speech'] = waveform
                    sample['sample_rate'] = rate
                    yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))


def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
    """ Stream samples from a ".vq0918-pool4.jsonl" manifest and its tar.

        The tar path is the manifest path with ".vq0918-pool4.jsonl"
        replaced by ".tar". Every manifest line (JSON with 'cosy_token' and
        'filename') produces one shallow copy of the in-place updated sample.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used by this opener
            tts_data: not used by this opener

        Returns:
            Iterable[dict] carrying 'speech_token', 'speech', 'sample_rate'.
            On any exception the manifest is abandoned with a warning.
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(".vq0918-pool4.jsonl", ".tar")
        try:
            archive = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as manifest:
                for raw_line in manifest:
                    record = json.loads(raw_line)
                    sample['speech_token'] = torch.tensor(record['cosy_token'])
                    audio_stream = io.BytesIO(archive.read(record['filename']))
                    loaded = torchaudio.load(audio_stream)
                    sample['speech'], sample['sample_rate'] = loaded
                    yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))


def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
    """ Stream samples from a ".vq0918-pool8.jsonl" manifest and its tar.

        The tar path is the manifest path with ".vq0918-pool8.jsonl"
        replaced by ".tar". Every manifest line (JSON with 'cosy_token' and
        'filename') produces one shallow copy of the in-place updated sample.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used by this opener
            tts_data: not used by this opener

        Returns:
            Iterable[dict] carrying 'speech_token', 'speech', 'sample_rate'.
            On any exception the manifest is abandoned with a warning.
    """
    for sample in data:
        assert 'src' in sample
        cosy_jsonl_path = sample['src']
        tar_file_path = cosy_jsonl_path.replace(".vq0918-pool8.jsonl", ".tar")

        try:
            archive = Tar(Path(tar_file_path))
            with open(cosy_jsonl_path, 'r') as manifest:
                for record in map(json.loads, manifest):
                    sample['speech_token'] = torch.tensor(record['cosy_token'])
                    member = archive.read(record['filename'])
                    waveform, rate = torchaudio.load(io.BytesIO(member))
                    sample['speech'] = waveform
                    sample['sample_rate'] = rate
                    yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
         


def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
    """ Load one (token, waveform) pair per sample from sibling files.

        ``sample['src']`` is the path of a ".vq0918-pool4.npy" token file;
        the audio path is the same path with that suffix removed.
        Multi-channel audio is averaged down to a single channel and the
        speaker embedding is set to zeros shaped like MAIN_SPK_EMBEDDING.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used
            tts_data: not used

        Returns:
            Iterable[dict] with 'speech_token', 'speech', 'sample_rate' and
            'spk_embedding'; samples that raise are logged and skipped.
    """
    for sample in data:
        assert 'src' in sample

        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")

        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            loaded = torchaudio.load(wav_path)
            sample['speech'], sample['sample_rate'] = loaded
            channels = sample['speech'].shape[0]
            if channels > 1:
                # down-mix to mono, keeping the channel axis
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
            logging.warning('Failed to open {}'.format(wav_path))


def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}):
    """ Emit growing prefixes of each (token, waveform) pair.

        ``sample['src']`` is a ".vq0918-pool4.npy" token file; the audio path
        is that path with the suffix removed. For k = 1, 2, ... one sample is
        yielded holding the first ``min(k * split_token, len(tokens))`` tokens
        and the first ``ceil(end_token / 12.5 * sample_rate)`` audio samples.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used
            split_token: prefix growth step, in tokens
            tts_data: not used

        Returns:
            Iterable[dict] with 'speech_token', 'speech', 'sample_rate' and a
            zero 'spk_embedding'; samples that raise are logged and skipped.
    """
    for sample in data:
        assert 'src' in sample

        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool4.npy", "")

        try:
            tokens = torch.tensor(np.load(token_npy_path))
            audio, rate = torchaudio.load(wav_path)
            if audio.shape[0] > 1:
                # down-mix to mono, keeping the channel axis
                audio = audio.mean(dim=0, keepdim=True)

            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = rate

            total_tokens = tokens.size(0)
            num_splits = (total_tokens + split_token - 1) // split_token

            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, total_tokens)
                # 12.5 tokens per second -> number of audio samples covered
                end_speech_idx = int(np.ceil(end_token_idx / 12.5 * rate))
                sample['speech_token'] = tokens[:end_token_idx]
                sample['speech'] = audio[:, :end_speech_idx]
                print(sample['speech_token'].size(),sample['speech'].size())
                yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
            logging.warning('Failed to open {}'.format(wav_path))


def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
    """ Load one (pool2 token, waveform) pair per sample from sibling files.

        The token path is ``sample['src']`` with ".vq0918-pool4.npy" replaced
        by ".vq0918-pool2.npy"; the audio path is the token path with the
        ".vq0918-pool2.npy" suffix removed. Multi-channel audio is averaged
        to one channel and the speaker embedding is set to zeros shaped like
        MAIN_SPK_EMBEDDING.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used
            tts_data: not used

        Returns:
            Iterable[dict] with 'speech_token', 'speech', 'sample_rate' and
            'spk_embedding'; samples that raise are logged and skipped.
    """
    for sample in data:
        assert 'src' in sample

        token_npy_path = sample['src'].replace(".vq0918-pool4.npy", ".vq0918-pool2.npy")
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")

        try:
            sample['speech_token'] = torch.tensor(np.load(token_npy_path))
            loaded = torchaudio.load(wav_path)
            sample['speech'], sample['sample_rate'] = loaded
            channels = sample['speech'].shape[0]
            if channels > 1:
                # down-mix to mono, keeping the channel axis
                sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)

            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
            logging.warning('Failed to open {}'.format(wav_path))
        

def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}):
    """ Emit growing prefixes of each (pool2 token, waveform) pair.

        ``sample['src']`` is a ".vq0918-pool2.npy" token file; the audio path
        is that path with the suffix removed. For k = 1, 2, ... one sample is
        yielded holding the first ``min(k * split_token, len(tokens))`` tokens
        and the first ``ceil(end_token / 25 * sample_rate)`` audio samples.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used
            split_token: prefix growth step, in tokens
            tts_data: not used

        Returns:
            Iterable[dict] with 'speech_token', 'speech', 'sample_rate' and a
            zero 'spk_embedding'; samples that raise are logged and skipped.
    """
    for sample in data:
        assert 'src' in sample

        token_npy_path = sample['src']
        wav_path = token_npy_path.replace(".vq0918-pool2.npy", "")

        try:
            tokens = torch.tensor(np.load(token_npy_path))
            audio, rate = torchaudio.load(wav_path)
            if audio.shape[0] > 1:
                # down-mix to mono, keeping the channel axis
                audio = audio.mean(dim=0, keepdim=True)

            sample['spk_embedding'] = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['sample_rate'] = rate

            total_tokens = tokens.size(0)
            num_splits = (total_tokens + split_token - 1) // split_token

            for split_id in range(num_splits):
                end_token_idx = min((split_id + 1) * split_token, total_tokens)
                # 25 tokens per second -> number of audio samples covered
                end_speech_idx = int(np.ceil(end_token_idx / 25 * rate))
                sample['speech_token'] = tokens[:end_token_idx]
                sample['speech'] = audio[:, :end_speech_idx]
                print(sample['speech_token'].size(),sample['speech'].size())
                yield dict(sample)
        except Exception as ex:
            logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
            logging.warning('Failed to open {}'.format(wav_path))

def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
    """ Yield one sample per response wav of a JSON conversation entry.

        ``sample['src']`` is a JSON string with a "conversations" list. For
        every conversation turn that has a "response_wav" key, the wav under
        /workspace/audio_data/sft/ and its ".wav.vq0918-pool4.npy" token file
        are loaded; multi-channel audio is averaged to one channel.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used
            tts_data: not used

        Returns:
            Iterable[dict] with 'speech_token', 'speech', 'sample_rate' and
            'spk_embedding'. An entry that raises is logged and abandoned.
    """
    for sample in data:
        assert 'src' in sample
        # Bound before the try so the error handler can always report
        # something, even when the JSON itself cannot be parsed.
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            # The original code assigned an undefined name `spk_embedding`
            # here, so every sample failed with NameError. Use the same zero
            # placeholder as the other process_sft_* stages.
            spk_embedding = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['spk_embedding'] = spk_embedding

            for conv in entry["conversations"]:
                if "response_wav" not in conv:
                    continue
                wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
                token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                if sample['speech'].shape[0] > 1:
                    # down-mix to mono, keeping the channel axis
                    sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                sample['spk_embedding'] = spk_embedding
                yield {**sample}
        except Exception as ex:
            failed = wav_path if wav_path is not None else sample['src']
            # Include the exception text: dropping it is what hid the
            # NameError above.
            logging.warning('Failed to open {}, ex info {}'.format(failed, ex))


def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
    """ Yield one sample per response/prompt wav of a JSON conversation entry.

        ``sample['src']`` is a JSON string with a "conversations" list. For
        every turn, the wav named by "response_wav" and then the wav named by
        "prompt_wav" (each if present) is loaded from
        /workspace/audio_data/sft/ together with its ".wav.vq0918-pool4.npy"
        token file; multi-channel audio is averaged to one channel.

        Args:
            data: Iterable[dict], each dict must contain 'src'
            mode: not used
            tts_data: not used

        Returns:
            Iterable[dict] with 'speech_token', 'speech', 'sample_rate' and
            'spk_embedding'. An entry that raises is logged and abandoned.
    """
    for sample in data:
        assert 'src' in sample
        # Bound before the try so the error handler can always report
        # something, even when the JSON itself cannot be parsed.
        wav_path = None
        try:
            entry = json.loads(sample['src'])
            # The original code assigned an undefined name `spk_embedding`
            # here, so every sample failed with NameError. Use the same zero
            # placeholder as the other process_sft_* stages.
            spk_embedding = torch.zeros_like(MAIN_SPK_EMBEDDING)
            sample['spk_embedding'] = spk_embedding

            for conv in entry["conversations"]:
                # The prompt branch used to read conv['response_wav'] (copy-
                # paste slip); each key now loads the wav it names.
                for wav_key in ("response_wav", "prompt_wav"):
                    if wav_key not in conv:
                        continue
                    wav_path = f"/workspace/audio_data/sft/{conv[wav_key]}"
                    token_npy_path = wav_path.replace(".wav", ".wav.vq0918-pool4.npy")
                    sample['speech_token'] = torch.tensor(np.load(token_npy_path))
                    sample['speech'], sample['sample_rate'] = torchaudio.load(wav_path)
                    if sample['speech'].shape[0] > 1:
                        # down-mix to mono, keeping the channel axis
                        sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    sample['spk_embedding'] = spk_embedding
                    yield {**sample}
        except Exception as ex:
            failed = wav_path if wav_path is not None else sample['src']
            # Include the exception text: dropping it is what hid the
            # NameError above.
            logging.warning('Failed to open {}, ex info {}'.format(failed, ex))


def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1,
           mode='train'):
    """ Drop samples whose audio length or text-token length is out of range.

        Audio length is measured in 10ms frames (100 frames per second),
        computed from ``sample['speech'].size(1)`` and ``sample['sample_rate']``.

        Args:
            data: Iterable[dict] with 'speech', 'sample_rate', 'text_token'
                and 'speech_token'
            max_length: drop samples longer than this many frames
            min_length: drop samples shorter than this many frames
            token_max_length: drop samples with more text tokens than this
            token_min_length: drop samples with fewer text tokens than this
            min_output_input_ratio: minimal allowed text_tokens / frames
            max_output_input_ratio: maximal allowed text_tokens / frames
            mode: not used

        Returns:
            Iterable[dict]: the surviving samples, unchanged. Samples with an
            empty 'speech_token' are dropped as well.
    """
    for sample in data:
        # 100 frames for every second of audio
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length or num_frames > max_length:
            continue
        num_text_tokens = len(sample['text_token'])
        if num_text_tokens < token_min_length or num_text_tokens > token_max_length:
            continue
        if len(sample['speech_token']) == 0:
            continue
        if num_frames != 0:
            ratio = num_text_tokens / num_frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue
        yield sample

            
def filter_speech_token(data,
           max_length=10240,
           min_length=10,
           token_max_length=5000,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=30,
           mode='train'):
    """ Drop samples whose audio length or speech-token length is out of range.

        Audio length is measured in 10ms frames (100 frames per second),
        computed from ``sample['speech'].size(1)`` and ``sample['sample_rate']``.

        Args:
            data: Iterable[dict] with 'speech', 'sample_rate', 'speech_token'
            max_length: drop samples longer than this many frames
            min_length: drop samples shorter than this many frames
            token_max_length: drop samples with more speech tokens than this
            token_min_length: drop samples with fewer speech tokens than this
            min_output_input_ratio: minimal allowed speech_tokens / frames
            max_output_input_ratio: maximal allowed speech_tokens / frames
            mode: not used

        Returns:
            Iterable[dict]: the surviving samples, unchanged. Samples with an
            empty 'speech_token' are always dropped.
    """
    for sample in data:
        # 100 frames for every second of audio
        num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length or num_frames > max_length:
            continue
        num_speech_tokens = len(sample['speech_token'])
        if num_speech_tokens < token_min_length or num_speech_tokens > token_max_length:
            continue
        if num_speech_tokens == 0:
            continue
        if num_frames != 0:
            ratio = num_speech_tokens / num_frames
            if ratio < min_output_input_ratio or ratio > max_output_input_ratio:
                continue
        yield sample


def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
    """ Bring every waveform to ``resample_rate`` and peak-normalize it.

        Samples are modified in place. A sample whose rate differs from
        ``resample_rate`` and is below ``min_sample_rate`` is dropped.
        After resampling, a waveform whose absolute peak exceeds 1 is divided
        (in place) by that peak.

        Args:
            data: Iterable[dict] with 'speech' and 'sample_rate'
            resample_rate: target sample rate
            min_sample_rate: lowest source rate that is still resampled
            mode: not used

        Returns:
            Iterable[dict]: the same dict objects with updated 'speech' and
            'sample_rate'
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        source_rate = sample['sample_rate']
        source_wave = sample['speech']
        needs_resample = source_rate != resample_rate
        if needs_resample and source_rate < min_sample_rate:
            continue
        if needs_resample:
            resampler = torchaudio.transforms.Resample(
                orig_freq=source_rate, new_freq=resample_rate)
            sample['sample_rate'] = resample_rate
            sample['speech'] = resampler(source_wave)
        peak = sample['speech'].abs().max()
        if peak > 1:
            # in-place scaling, same tensor object stays in the sample
            sample['speech'].div_(peak)
        yield sample


def compute_fbank(data,
                  feat_extractor,
                  mode='train'):
    """ Replace each waveform by the features computed from it.

        ``feat_extractor(waveform)`` is squeezed on dim 0 and its first two
        remaining dims are swapped; the result is stored as 'speech_feat' and
        the 'speech' entry is removed from the sample.

        Args:
            data: Iterable[dict] with 'speech' and 'sample_rate'
            feat_extractor: callable applied to the waveform tensor
            mode: not used

        Returns:
            Iterable[dict]: the same dict objects, with 'speech_feat' instead
            of 'speech'
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'speech' in sample
        features = feat_extractor(sample['speech'])
        sample['speech_feat'] = features.squeeze(dim=0).transpose(0, 1)
        del sample['speech']
        yield sample


def parse_embedding(data, normalize, mode='train'):
    """ Convert 'utt_embedding' and 'spk_embedding' to float32 tensors.

        Args:
            data: Iterable[dict] with 'utt_embedding' and 'spk_embedding'
            normalize: when truthy, both embeddings are passed through
                ``F.normalize(..., dim=0)``
            mode: not used

        Returns:
            Iterable[dict]: the same dict objects with tensor embeddings
    """
    embedding_keys = ('utt_embedding', 'spk_embedding')
    for sample in data:
        for key in embedding_keys:
            sample[key] = torch.tensor(sample[key], dtype=torch.float32)
        if normalize:
            for key in embedding_keys:
                sample[key] = F.normalize(sample[key], dim=0)
        yield sample


def tokenize(data, get_tokenizer, allowed_special, mode='train'):
    """ Encode the text of every sample into token ids.

        The tokenizer is created once, on first iteration, by calling
        ``get_tokenizer()``. Samples are modified in place.

        Args:
            data: Iterable[dict] with 'text' (and 'tts_text' in inference mode)
            get_tokenizer: zero-argument callable returning an object with
                ``encode(text, allowed_special=...)``
            allowed_special: forwarded to ``encode``
            mode: in 'inference' mode 'tts_text' is encoded as well

        Returns:
            Iterable[dict] with 'text_token' (and 'tts_text_token') added
    """
    tokenizer = get_tokenizer()

    def encode(text):
        return tokenizer.encode(text, allowed_special=allowed_special)

    also_encode_tts = mode == 'inference'
    for sample in data:
        assert 'text' in sample
        sample['text_token'] = encode(sample['text'])
        if also_encode_tts:
            sample['tts_text_token'] = encode(sample['tts_text'])
        yield sample


def shuffle(data, shuffle_size=10000, mode='train'):
    """ Shuffle the stream locally, one buffer at a time.

        Samples are collected until ``shuffle_size`` of them are buffered,
        the buffer is shuffled with ``random.shuffle`` and emitted, and the
        process repeats. The final, possibly smaller buffer is handled the
        same way.

        Args:
            data: Iterable of samples
            shuffle_size: number of samples per shuffle buffer
            mode: not used

        Returns:
            Iterable of the same samples in locally shuffled order
    """
    pool = []
    for sample in data:
        pool.append(sample)
        if len(pool) < shuffle_size:
            continue
        random.shuffle(pool)
        yield from pool
        pool = []
    # whatever did not fill a complete buffer
    random.shuffle(pool)
    yield from pool


def sort(data, sort_size=500, mode='train'):
    """ Sort the stream locally by number of feature frames.

        Samples are buffered ``sort_size`` at a time and each buffer is
        emitted in ascending order of ``sample['speech_feat'].size(0)``.
        Meant to run after shuffle and before batch so that utterances of
        similar length end up in the same batch; ``sort_size`` should be
        smaller than the shuffle buffer.

        Args:
            data: Iterable[dict] with 'speech_feat'
            sort_size: number of samples per sort buffer
            mode: not used

        Returns:
            Iterable[dict] in locally sorted order
    """

    def frame_count(item):
        return item['speech_feat'].size(0)

    window = []
    for sample in data:
        window.append(sample)
        if len(window) < sort_size:
            continue
        yield from sorted(window, key=frame_count)
        window = []
    # whatever did not fill a complete buffer
    yield from sorted(window, key=frame_count)


def static_batch(data, batch_size=16):
    """ Group the stream into lists of ``batch_size`` samples.

        Args:
            data: Iterable of samples
            batch_size: number of samples per batch

        Returns:
            Iterable[List]: full batches, followed by one shorter batch when
            samples are left over
    """
    pending = []
    for sample in data:
        pending.append(sample)
        if len(pending) < batch_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending


def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        The cost of a batch is ``longest_sample_frames * batch_len`` (its
        size after padding), where frames are ``speech_feat.size(0)``. A
        sample that alone exceeds the budget is emitted as a batch of one;
        an empty batch is never yielded.

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch
            mode: not used

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'speech_feat' in sample
        assert isinstance(sample['speech_feat'], torch.Tensor)
        new_sample_frames = sample['speech_feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            # buf is empty when the very first sample of a batch is already
            # over budget; emitting it would hand an empty batch downstream.
            if buf:
                yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if buf:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
    """ Wrapper for static/dynamic batch

        Args:
            data: Iterable of samples
            batch_type: 'static' or 'dynamic' (ignored in inference mode)
            batch_size: batch size for static batching
            max_frames_in_batch: frame budget for dynamic batching
            mode: in 'inference' mode every batch holds exactly one sample

        Returns:
            Iterable[List] of batched samples

        Raises:
            ValueError: if ``batch_type`` is not supported. Previously this
                case only logged and returned None, which made the failure
                surface later as an unrelated error in the consumer.
    """
    if mode == 'inference':
        return static_batch(data, 1)
    if batch_type == 'static':
        return static_batch(data, batch_size)
    if batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    message = 'Unsupported batch type {}'.format(batch_type)
    logging.fatal(message)
    raise ValueError(message)


def padding(data, use_spk_embedding, mode='train'):
    """ Padding the data into training data

        Each incoming batch (a list of sample dicts) is reordered by
        descending number of feature frames and collated into one dict of
        padded tensors.

        Args:
            data: Iterable[List[dict]]; every sample needs 'utt',
                'speech_token', 'speech_feat', 'text', 'text_token',
                'utt_embedding' and 'spk_embedding' (plus 'tts_text',
                'tts_index', 'tts_text_token' in inference mode)
            use_spk_embedding: when exactly True, 'embedding' is the stacked
                speaker embedding, otherwise the stacked utterance embedding
            mode: 'inference' adds the tts_* entries to the batch

        Returns:
            Iterable[dict] with keys utts, speech_token, speech_token_len,
            speech_feat, speech_feat_len, text, text_token, text_token_len,
            utt_embedding, spk_embedding, embedding (and tts_* in inference).
            Token and feature sequences are padded with 0, tts_text_token
            with -1; all *_len tensors are int32.
    """
    for sample in data:
        assert isinstance(sample, list)
        # Order by frame count, i.e. dim 0 of speech_feat (the dim that
        # `sort`, `dynamic_batch` and the length tensor below use). The
        # original sorted on size(1), which is not the sequence length, so
        # the batch was never actually ordered by length.
        frame_counts = torch.tensor([x['speech_feat'].size(0) for x in sample],
                                    dtype=torch.int32)
        order = torch.argsort(frame_counts, descending=True).tolist()

        utts = [sample[i]['utt'] for i in order]
        speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
        speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
        speech_token = pad_sequence(speech_token,
                                    batch_first=True,
                                    padding_value=0)
        speech_feat = [sample[i]['speech_feat'] for i in order]
        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
        speech_feat = pad_sequence(speech_feat,
                                   batch_first=True,
                                   padding_value=0)
        text = [sample[i]['text'] for i in order]
        text_token = [torch.tensor(sample[i]['text_token']) for i in order]
        text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
        text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
        batch = {
            "utts": utts,
            "speech_token": speech_token,
            "speech_token_len": speech_token_len,
            "speech_feat": speech_feat,
            "speech_feat_len": speech_feat_len,
            "text": text,
            "text_token": text_token,
            "text_token_len": text_token_len,
            "utt_embedding": utt_embedding,
            "spk_embedding": spk_embedding,
        }
        if mode == 'inference':
            tts_text = [sample[i]['tts_text'] for i in order]
            tts_index = [sample[i]['tts_index'] for i in order]
            tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
            tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
            tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
            batch.update({'tts_text': tts_text,
                          'tts_index': tts_index,
                          'tts_text_token': tts_text_token,
                          'tts_text_token_len': tts_text_token_len})
        if use_spk_embedding is True:
            batch["embedding"] = batch["spk_embedding"]
        else:
            batch["embedding"] = batch["utt_embedding"]
        yield batch



def padding_speech_token(data, use_spk_embedding, mode='train'):
    """ Pad lists of samples into batches of speech tokens and features

        Every incoming list is ordered by speech feature length (dim 0,
        longest first; ties keep their input order), padded and packed into
        one dict. No speaker information is used: ``embedding`` is always a
        zero tensor of shape (batch, 192). A list that cannot be collated
        (missing key, mismatching shapes, empty list, ...) is logged and
        skipped instead of aborting the whole pipeline.

        Args:
            data: Iterable[List[{speech_token, speech_feat}]], both values
                being tensors; in 'inference' mode every sample must also
                carry tts_text, tts_index and tts_text_token
            use_spk_embedding: ignored by this function, only kept so the
                signature matches the other padding functions
            mode: 'inference' adds the tts_* entries to the batch

        Returns:
            Iterable[{speech_token, speech_token_len, speech_feat,
                      speech_feat_len, embedding[, tts_text, tts_index,
                      tts_text_token, tts_text_token_len]}]
    """
    for sample in data:
        assert isinstance(sample, list)
        try:
            # Sort on dim 0: it is the dimension pad_sequence pads and the one
            # reported as speech_feat_len below. sorted() is stable, so equal
            # lengths keep their incoming order.
            order = sorted(range(len(sample)),
                           key=lambda idx: sample[idx]['speech_feat'].size(0),
                           reverse=True)

            speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
            speech_token_len = torch.tensor([t.size(0) for t in speech_token],
                                            dtype=torch.int32)
            speech_token = pad_sequence(speech_token,
                                        batch_first=True,
                                        padding_value=0)

            speech_feat = [sample[i]['speech_feat'] for i in order]
            speech_feat_len = torch.tensor([f.size(0) for f in speech_feat],
                                           dtype=torch.int32)
            speech_feat = pad_sequence(speech_feat,
                                       batch_first=True,
                                       padding_value=0)

            batch = {
                "speech_token": speech_token,
                "speech_token_len": speech_token_len,
                "speech_feat": speech_feat,
                "speech_feat_len": speech_feat_len,
            }
            if mode == 'inference':
                tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
                batch.update({
                    'tts_text': [sample[i]['tts_text'] for i in order],
                    'tts_index': [sample[i]['tts_index'] for i in order],
                    # lengths must be taken before padding; -1 marks padding
                    'tts_text_token_len': torch.tensor([t.size(0) for t in tts_text_token],
                                                       dtype=torch.int32),
                    'tts_text_token': pad_sequence(tts_text_token,
                                                   batch_first=True,
                                                   padding_value=-1),
                })
            # Placeholder embedding; 192 is the same width as the fallback
            # speaker embeddings created at the top of this file.
            batch["embedding"] = torch.zeros((speech_feat.size(0), 192),
                                             device=speech_feat.device)
        except Exception as ex:
            logging.warning(' ex info {}'.format(ex))
            continue
        # Yield outside the try block so that an exception thrown into the
        # generator by its consumer is not mistaken for a collation failure.
        yield batch



def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
    """ Pad lists of samples into batches of speech tokens, features and
        speaker embeddings

        Every incoming list is ordered by speech feature length (dim 0,
        longest first; ties keep their input order), padded and packed into
        one dict. ``embedding`` is always the stacked speaker embedding. A
        list that cannot be collated (missing key, mismatching shapes, empty
        list, ...) is logged and skipped instead of aborting the whole
        pipeline.

        Args:
            data: Iterable[List[{speech_token, speech_feat, spk_embedding}]],
                all values being tensors; in 'inference' mode every sample
                must also carry tts_text, tts_index and tts_text_token
            use_spk_embedding: ignored by this function (the speaker
                embedding is always used), only kept so the signature
                matches the other padding functions
            mode: 'inference' adds the tts_* entries to the batch

        Returns:
            Iterable[{speech_token, speech_token_len, speech_feat,
                      speech_feat_len, spk_embedding, embedding[, tts_text,
                      tts_index, tts_text_token, tts_text_token_len]}]
    """
    for sample in data:
        assert isinstance(sample, list)
        try:
            # Sort on dim 0: it is the dimension pad_sequence pads and the one
            # reported as speech_feat_len below. sorted() is stable, so equal
            # lengths keep their incoming order.
            order = sorted(range(len(sample)),
                           key=lambda idx: sample[idx]['speech_feat'].size(0),
                           reverse=True)

            speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
            speech_token_len = torch.tensor([t.size(0) for t in speech_token],
                                            dtype=torch.int32)
            speech_token = pad_sequence(speech_token,
                                        batch_first=True,
                                        padding_value=0)

            speech_feat = [sample[i]['speech_feat'] for i in order]
            speech_feat_len = torch.tensor([f.size(0) for f in speech_feat],
                                           dtype=torch.int32)
            speech_feat = pad_sequence(speech_feat,
                                       batch_first=True,
                                       padding_value=0)

            spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)

            batch = {
                "speech_token": speech_token,
                "speech_token_len": speech_token_len,
                "speech_feat": speech_feat,
                "speech_feat_len": speech_feat_len,
                "spk_embedding": spk_embedding,
            }
            if mode == 'inference':
                tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
                batch.update({
                    'tts_text': [sample[i]['tts_text'] for i in order],
                    'tts_index': [sample[i]['tts_index'] for i in order],
                    # lengths must be taken before padding; -1 marks padding
                    'tts_text_token_len': torch.tensor([t.size(0) for t in tts_text_token],
                                                       dtype=torch.int32),
                    'tts_text_token': pad_sequence(tts_text_token,
                                                   batch_first=True,
                                                   padding_value=-1),
                })
            # The conditioning embedding is the speaker embedding itself
            # (same tensor object, not a copy).
            batch["embedding"] = batch["spk_embedding"]
        except Exception as ex:
            logging.warning(' ex info {}'.format(ex))
            continue
        # Yield outside the try block so that an exception thrown into the
        # generator by its consumer is not mistaken for a collation failure.
        yield batch